// Copyright © 2025 blackshirt.
// Use of this source code is governed by an MIT license
// that can be found in the LICENSE file.
//
// This module implements the building blocks for the Elliptic-Curve Diffie-Hellman
// (ECDH) key exchange mechanism over the Curve25519 curve (X25519).
module curve25519

import crypto.rand
import crypto.internal.subtle
import crypto.ed25519.internal.edwards25519

// scalar_size is the length, in bytes, of a Curve25519 scalar (private key).
const scalar_size = 32

// point_size is the length, in bytes, of an encoded Curve25519 point (u-coordinate).
const point_size = 32

// zero_point is the 32-byte all-zeros encoding, used to reject degenerate
// inputs and outputs.
const zero_point = []u8{len: point_size, init: u8(0x00)}

// base_point is the canonical Curve25519 generator (u = 9): a single byte
// with value 9 followed by 31 zero bytes.
const base_point = [
	u8(0x09), 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
]

// PrivateKey represents a Curve25519 (X25519) private key.
// The struct is marked `@[noinit]`, so code outside this module must obtain one
// through `PrivateKey.new` or `PrivateKey.new_from_seed`.
@[noinit]
pub struct PrivateKey {
mut:
	// done is set to true by `.free`. Once it is true the key must not be used
	// again: the methods return an error (or panic, for `equal`) instead.
	done bool
	// key holds the clamped scalar, 32 bytes long.
	key []u8
}

// new creates a new Curve25519 private key from `scalar_size` random bytes read
// from `crypto.rand`. The bytes are clamped (RFC 7748) before being stored.
// It returns an error when reading randomness or clamping fails, or when the
// clamped bytes are all zeros or equal to the base point.
@[direct_array_access]
pub fn PrivateKey.new() !&PrivateKey {
	mut scalar := rand.read(scalar_size)!
	clamp(mut scalar)!

	// only accept non-degenerate scalars
	if !(is_zero_point(scalar) || is_base_point(scalar)) {
		return &PrivateKey{
			key: scalar
		}
	}
	return error('PrivateKey.new: bytes zeros or base point')
}

// new_from_seed creates a new Curve25519 private key from a caller-supplied
// seed of exactly `scalar_size` bytes. The seed is copied first, leaving the
// caller's buffer untouched. The copy is then clamped (RFC 7748).
// It returns an error on a wrong seed length, or when the clamped bytes are all
// zeros or equal to the base point.
@[direct_array_access]
pub fn PrivateKey.new_from_seed(seed []u8) !&PrivateKey {
	if seed.len == scalar_size {
		mut scalar := seed.clone()
		clamp(mut scalar)!
		if is_zero_point(scalar) || is_base_point(scalar) {
			return error('PrivateKey.new_from_seed: bytes was zeros or base point')
		}
		return &PrivateKey{
			key: scalar
		}
	}
	return error('invalid scalar size')
}

// public_key computes the public key that belongs to this private key: the
// X25519 product of the private scalar with `base_point`.
// It fails when the key has already been freed.
@[direct_array_access]
pub fn (mut pv PrivateKey) public_key() !&PublicKey {
	if !pv.done {
		encoded := x25519(mut pv.key, base_point)!
		return &PublicKey{
			key: encoded
		}
	}
	return error('your key has been marked as freed')
}

// equal reports whether `pv` and `oth` hold the same key bytes, compared with
// `subtle.constant_time_compare`. Keys whose length is not `scalar_size` never
// compare equal. It panics when either key has been freed.
@[direct_array_access]
pub fn (pv PrivateKey) equal(oth PrivateKey) bool {
	if pv.done || oth.done {
		panic('the key has been marked as freed')
	}
	well_sized := pv.key.len == scalar_size && oth.key.len == scalar_size
	return well_sized && subtle.constant_time_compare(pv.key, oth.key) == 1
}

// x25519 multiplies this private scalar with `point` (a 32-byte encoded point)
// and returns the resulting 32-byte point.
// It rejects freed keys, points of the wrong size, the all-zeros point, and a
// point that is byte-identical to the private key itself.
@[direct_array_access]
pub fn (mut pv PrivateKey) x25519(point []u8) ![]u8 {
	if pv.done {
		return error('PrivateKey has been marked as freed')
	}
	if point.len != point_size {
		return error('bad point size, should be 32')
	}
	// cheap early exit: an all-zeros point is never accepted, so skip the
	// expensive ladder in the module-level `x25519`
	if is_zero(point) {
		return error('x25519: get zeros point')
	}
	// policy: although technically possible, a point equal to the key is refused
	same_as_key := subtle.constant_time_compare(pv.key, point) == 1
	if same_as_key {
		return error('pv.key identical with point')
	}
	return x25519(mut pv.key, point)!
}

// bytes returns a fresh copy of the private key bytes, so mutating the result
// does not affect the key. It fails when the key has been freed.
pub fn (pv PrivateKey) bytes() ![]u8 {
	if !pv.done {
		return pv.key.clone()
	}
	return error('PrivateKey has been marked as freed')
}

// free releases the underlying key bytes. Before the memory is released, the
// secret bytes are overwritten with zeros (best effort), so the key material
// does not linger in freed memory.
// Calling free more than once is a no-op. Do not use the key after calling .free.
@[unsafe]
pub fn (mut pv PrivateKey) free() {
	// A key already marked as done (freed) must not be freed again: a double
	// free would lead to undefined behavior, so bail out early.
	if pv.done {
		return
	}
	// wipe the secret scalar before handing its memory back to the allocator
	for i in 0 .. pv.key.len {
		pv.key[i] = 0
	}
	unsafe { pv.key.free() }
	// mark the key as finished so the other methods refuse to use it
	pv.done = true
}

// PublicKey represents a Curve25519 (X25519) public key.
// The struct is marked `@[noinit]`, so code outside this module obtains one
// through `PublicKey.new_from_bytes` or `PrivateKey.public_key`.
@[noinit]
pub struct PublicKey {
mut:
	// key is the 32-byte encoded point. When produced by `PrivateKey.public_key`,
	// it is the private scalar multiplied by the base point.
	key []u8
}

// new_from_bytes creates a new Curve25519 public key from the provided bytes.
// The bytes are copied, so later changes to the caller's buffer do not alter
// the key. It returns an error when `bytes` is not `point_size` (32) bytes long
// or is all zeros.
pub fn PublicKey.new_from_bytes(bytes []u8) !&PublicKey {
	if bytes.len != point_size {
		return error('PublicKey.new: bad bytes length')
	}
	// According to D.J. Bernstein, the designer of Curve25519, public key
	// validation is generally not needed for Diffie-Hellman key exchange.
	// See https://cr.yp.to/ecdh.html#validate
	// Other commonly suggested hardening measures include:
	// - blacklisting the known bad public keys,
	// - checking the shared value and raising an error if it is zero,
	// - binding the exchanged public keys to the shared key, i.e. using
	//   H(aG || bG || abG) instead of H(abG) as the shared key.
	//
	// Here we only reject the all-zeros public key.
	if is_zero(bytes) {
		return error('PublicKey.new: get zeros bytes')
	}
	// Store a private copy: keeping `bytes` itself would alias the caller's
	// buffer (PrivateKey.new_from_seed clones its input for the same reason).
	return &PublicKey{
		key: bytes.clone()
	}
}

// equal reports whether `pb` and `other` encode the same point, compared with
// `subtle.constant_time_compare`. Keys that are not exactly `point_size` bytes
// long (which should not happen) never compare equal.
pub fn (pb PublicKey) equal(other PublicKey) bool {
	both_sized := pb.key.len == point_size && other.key.len == point_size
	return both_sized && subtle.constant_time_compare(pb.key, other.key) == 1
}

// bytes returns a fresh copy of the encoded public key. It fails when the
// stored key is not `point_size` bytes long.
pub fn (pb PublicKey) bytes() ![]u8 {
	if pb.key.len == point_size {
		return pb.key.clone()
	}
	return error('bad public key size')
}

// SharedOpts holds the configuration options for the `derive_shared_secret` routine.
@[params]
pub struct SharedOpts {
pub mut:
	// should_derive, when true, makes `derive_shared_secret` pass the raw shared
	// secret through `derivator` and return the derived bytes.
	should_derive bool
	// derivator is the key derivation function to apply. The default,
	// RawDerivator, returns its input unchanged.
	derivator     Derivator = RawDerivator{}
	// drv_opts is forwarded to `derivator.derive`.
	drv_opts      DeriveOpts
}

// DeriveOpts is the configuration passed to a Derivator's `derive` operation.
// It currently has no fields (presumably reserved for future options).
@[params]
pub struct DeriveOpts {}

// Derivator represents a key derivation function applied to a raw shared secret.
pub interface Derivator {
	// derive transforms the bytes in `sec` into another byte sequence,
	// configured by `opt`.
	derive(sec []u8, opt DeriveOpts) ![]u8
}

// RawDerivator is the identity derivator: it performs no derivation and hands
// the secret back unchanged (the same slice, not a copy).
struct RawDerivator {}

// derive returns `sec` as is; `opt` is ignored.
fn (d RawDerivator) derive(sec []u8, opt DeriveOpts) ![]u8 {
	return sec
}

// derive_shared_secret computes the shared secret between the `local` peer
// (holder of the private key) and the `remote` peer (owner of the public key).
// It accepts SharedOpts to optionally run the secret through another key
// derivation mechanism.
//
// 6.  Diffie-Hellman with Curve25519
// See https://datatracker.ietf.org/doc/html/rfc7748#section-6
//
// It returns an error when both peers have the same public key, when the raw
// secret is all zeros, or when the derived secret is all zeros.
pub fn derive_shared_secret(mut local PrivateKey, remote PublicKey, opt SharedOpts) ![]u8 {
	// TODO: should this check be relaxed ?
	// safety: refuse to agree on a secret with our own public key
	own_pubkey := local.public_key()!
	if own_pubkey.equal(remote) {
		return error('unallowed equal public key between peer')
	}
	// Each peer combines its own private key with the other peer's public key:
	//   X25519(local_privkey, remote_pubkey) == X25519(remote_privkey, local_pubkey)
	// That common value is the shared secret, used directly as a key or as the
	// input to a key derivation function.
	raw := local.x25519(remote.key)!
	// x25519 itself only rejects an all-zeros result on the non-base-point
	// branch of x25519_generic, so check explicitly here.
	if is_zero(raw) {
		return error('zeroes shared secret')
	}
	if !opt.should_derive {
		// no derivation requested: return the raw secret as is
		return raw
	}
	// The raw secret can be used directly, but applying a KDF such as HKDF is
	// often recommended to obtain a more robust key.
	derived := opt.derivator.derive(raw, opt.drv_opts)!
	if is_zero(derived) {
		return error('zeroes shared secret after derivation')
	}
	return derived
}

// x25519 returns the result of the scalar multiplication (`scalar` * `point`),
// according to RFC 7748, Section 5. `scalar`, `point` and the return value are
// 32-byte slices: the function takes a scalar and a `u-coordinate` as inputs
// and produces a `u-coordinate` as output.
// Note: `scalar` is clamped IN PLACE, so the caller's buffer is modified.
// `scalar` can be generated at random, for example with `crypto.rand`. `point`
// should be either `base_point` or the output of another `x25519` call.
// An all-zeros scalar or point is rejected.
@[direct_array_access]
pub fn x25519(mut scalar []u8, point []u8) ![]u8 {
	// reject all-zeros inputs up front (the scalar is tested before clamping)
	if is_zero(point) || is_zero(scalar) {
		return error('x25519: unallowed zeros/scalar point')
	}
	// make sure the scalar is in the clamped form that scalar_mult assumes
	clamp(mut scalar)!
	mut out := []u8{len: point_size}
	return x25519_generic(mut out, mut scalar, point)
}

// x25519_generic writes `scalar` * `point` into `dst` and returns `dst`.
// Array lengths are validated by `scalar_mult`; `scalar` is expected to be
// clamped already. On the base-point branch a result equal to the base point
// is rejected. On every other point an all-zeros result (low-order input point)
// is rejected.
@[direct_array_access; inline]
fn x25519_generic(mut dst []u8, mut scalar []u8, point []u8) ![]u8 {
	if !is_base_point(point) {
		scalar_mult(mut dst, mut scalar, point)!
		if is_zero_point(dst) {
			return error('bad input point: low order point')
		}
		return dst
	}
	scalar_base_mult(mut dst, mut scalar)!
	if is_base_point(dst) {
		return error('dst: get base_point')
	}
	return dst
}

// scalar_base_mult writes `scalar` * `base_point` into `dst`, propagating any
// error from `scalar_mult`.
@[direct_array_access]
fn scalar_base_mult(mut dst []u8, mut scalar []u8) ! {
	scalar_mult(mut dst, mut scalar, base_point) or { return err }
}

// scalar_mult performs scalar multiplication (scalar * point) on the curve25519 curve
// and writes the 32-byte encoded result into `dst`.
// It is the main routine behind every scalar multiplication in this module.
// For performance reasons `scalar` is marked as mutable; this routine only reads it.
// It returns an error when `dst`, `scalar` or `point` is not 32 bytes long, or when
// `point` is rejected by `Element.set_bytes`.
@[direct_array_access; inline]
fn scalar_mult(mut dst []u8, mut scalar []u8, point []u8) ! {
	if dst.len != point_size {
		return error('bad dst length')
	}
	if scalar.len != scalar_size {
		return error('scalar.lenght != 32')
	}
	if point.len != point_size {
		return error('point.lenght != 32')
	}
	// Note: the scalar is NOT clamped here; clamping is the caller's
	// responsibility and the scalar is assumed to be clamped already.
	mut x1 := edwards25519.Element{}
	mut x2 := edwards25519.Element{}
	mut z2 := edwards25519.Element{}
	mut x3 := edwards25519.Element{}
	mut z3 := edwards25519.Element{}
	mut tmp0 := edwards25519.Element{}
	mut tmp1 := edwards25519.Element{}

	// x1 = input u-coordinate; ladder state starts as (x2 : z2) = (1 : z2) and
	// (x3 : z3) = (x1 : 1). z2 is left at its Element{} default (presumably 0).
	x1.set_bytes(point[..])!
	x2.one()
	x3.set(x1)
	z3.one()

	// Montgomery ladder over scalar bits 254 down to 0 (bit 255 is not read).
	// `swap` tracks whether the two ladder points are currently exchanged; the
	// conditional swaps are driven by the XOR of consecutive bits.
	mut swap := 0
	for pos := 254; pos >= 0; pos-- {
		// extract bit `pos` of the little-endian scalar
		mut b := scalar[pos / 8] >> u32(pos & 7)
		b &= 1
		swap = swap ^ int(b)
		x2.swap(mut x3, swap)
		z2.swap(mut z3, swap)
		swap = int(b)

		// combined differential addition and doubling step
		tmp0.subtract(x3, z3)
		tmp1.subtract(x2, z2)
		x2.add(x2, z2)
		z2.add(x3, z3)
		z3.multiply(tmp0, x2)
		z2.multiply(z2, tmp1)
		tmp0.square(tmp1)
		tmp1.square(x2)
		x3.add(z3, z2)
		z2.subtract(z3, z2)
		x2.multiply(tmp1, tmp0)
		tmp1.subtract(tmp1, tmp0)
		z2.square(z2)

		// 121666 = (A + 2) / 4 for the Curve25519 coefficient A = 486662
		z3.mult_32(tmp1, 121666)
		x3.square(x3)
		tmp0.add(tmp0, z3)
		z3.multiply(x1, z2)
		z2.multiply(tmp1, tmp0)
	}

	// undo the last pending swap
	x2.swap(mut x3, swap)
	z2.swap(mut z3, swap)

	// convert the projective result to the affine u-coordinate: u = x2 / z2
	z2.invert(z2)
	x2.multiply(x2, z2)
	copy(mut dst, x2.bytes())
}

// Utility helpers
//

// is_zero_point reports whether `point` equals the 32-byte all-zeros encoding,
// compared with `subtle.constant_time_compare`. Any other length yields false.
@[direct_array_access; inline]
fn is_zero_point(point []u8) bool {
	return point.len == point_size && subtle.constant_time_compare(point, zero_point) == 1
}

// is_base_point reports whether `point` equals the canonical `base_point`
// encoding, compared with `subtle.constant_time_compare`. Any other length
// yields false.
@[direct_array_access; inline]
fn is_base_point(point []u8) bool {
	return point.len == point_size && subtle.constant_time_compare(point, base_point) == 1
}

// clamp turns 32 arbitrary bytes into an X25519 scalar in place, as described
// by RFC 7748 for decoding random bytes as a scalar:
//   - clear the three least significant bits of the first byte,
//   - clear the most significant bit of the last byte,
//   - set the second most significant bit of the last byte.
// It returns an error when `seed` is not `scalar_size` bytes long.
@[direct_array_access; inline]
fn clamp(mut seed []u8) ! {
	if seed.len != scalar_size {
		return error('bad seed sizes for clamp')
	}
	seed[0] &= 0xf8
	seed[31] = (seed[31] & 0x7f) | 0x40
}

// is_zero reports whether every byte of `seed` is zero. All bytes are OR-ed
// together without an early exit, so the loop always visits the whole slice.
// An empty slice counts as zero.
fn is_zero(seed []u8) bool {
	mut bits := u8(0)
	for i in 0 .. seed.len {
		bits |= seed[i]
	}
	return bits == 0
}
